import copy
import json
import os
import pickle
from math import floor

from PIL import Image
from matplotlib import cm

import numpy as np
import torch


# ara= torch.randn(10)*4
# print(ara)
# ret= map_to_exact_discrete(ara)
# print(ret)
from matplotlib import pyplot as plt
from numpy import uint8

from ModularUtils.ControllerModel import get_generators, get_generated_labels
from ModularUtils.DigitImageGeneration.mnist_image_generation import plot_trained_digits
from ModularUtils.Experiment_Class import Experiment

# Experiment hyper-parameters, kept in one ordered table and unpacked into
# Experiment as keyword arguments.
_exp_kwargs = dict(
    dist_thresh=0.15,
    causal_hierarchy=2,
    Temperature=1,
    temp_min=0.01,
    NOISE_DIM=128,
    CONF_NOISE_DIM=128,
    G_hid_dims=[256, 256],
    D_hid_dims=[256, 256, 256],
    IMAGE_FILTERS=[128, 64, 32],
    CRITIC_ITERATIONS=1,
    LAMBDA_GP=10,
    learning_rate=2 * 1e-4,
    Synthetic_Sample_Size=40000,
    intv_Sample_Size=40000,
    batch_size=200,
    features=["feature"],
    noise_states=100,
    latent_state=16,
    # Data_intervs=[{}, {"X1": 0}, {"X1": 1}],
    Data_intervs=[{}],
    num_epochs=300,
    new_experiment=False,
)

# NOTE(review): `set_nonid_mnist_images` is not imported or defined in this
# file; unless it is made available some other way this line raises
# NameError. TODO import it from the module that defines it.
Exp = Experiment("Exp1", set_nonid_mnist_images, **_exp_kwargs)

# Re-assert the sample size after construction and let interventional
# batches use the same batch size as observational ones.
Exp.Synthetic_Sample_Size = 40000
Exp.intv_batch_size = Exp.batch_size

# Start with an empty true_bn mapping.
Exp.true_bn = {}


# SHARED_INFO.txt is parsed as JSON; its "last_exp" entry is the saved run
# whose models will be loaded.
SHARED_INFO = "/".join(("SAVED_EXPERIMENTS", Exp.Complete_DAG_desc, "SHARED_INFO.txt"))
with open(SHARED_INFO) as shared_file:
    INSTANCE = json.load(shared_file)

# Example value: "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/Exp1/Sep_23_2022-10_45"
last_exp = INSTANCE["last_exp"]
print(last_exp)
Exp.LOAD_MODEL_PATH = last_exp

# Per-mechanism load flags handed to get_generators; presumably True means
# "restore this generator from Exp.LOAD_MODEL_PATH" (TODO confirm). Only the
# ImgYdigit1 image mechanism is flagged.
load_which_models = {
    "X1": False,
    "X2": False,
    "W": False,
    "Ydigit1": False,
    "Ydigit2": False,
    "Ycolor": False,
    "Ythick": False,
    "ImgYdigit1": True,
    "ImgYdigit2": False,
}
# Mechanisms that get_generated_labels is asked to run below.
cur_mechs = ["ImgYdigit1"]

label_generators, optimizersMech = get_generators(Exp, load_which_models)

# Put every generator in eval mode before sampling from it.
for mech_name in label_generators:
    label_generators[mech_name].eval()


# How many observed combinations to render, and the width in pixels of the
# white gutter drawn after each rendered image pair.
MAX_COMBS = 10
SEPARATOR_WIDTH = 3


def _generate_digit_image(digit, col, thick):
    """Generate one image from the current mechanisms under an intervention.

    Intervenes on ``Ydigit1``/``Ycolor``/``Ythick`` and returns the entry of
    the generated-labels dict keyed by ``Exp.image_labels[0]``. Uses the
    module-level ``Exp``, ``label_generators`` and ``cur_mechs``. The result is
    concatenated along dim 0 and each element is permuted as a 3-D (C, H, W)
    tensor below, so it is presumably shaped (k, C, H, W) -- TODO confirm.
    """
    intv_pro = {"Ydigit1": digit, "Ycolor": col, "Ythick": thick}
    generated_labels_dict = get_generated_labels(
        Exp, label_generators, {}, {}, intv_pro, cur_mechs, 1, hard=True
    )
    return generated_labels_dict[Exp.image_labels[0]]


with torch.no_grad():
    # Alternative label files (swap into file_name if needed):
    #   "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/(0, 1)P(V|X1,X2).txt"
    #   "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/(1, 4)P(V|do(X1,X2)).txt"
    #   "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/14|01P(Y|do(X1,X2),X1p,X2p).txt"  #intv|evidence
    #   "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/15|16P(Y|do(X1,X2),X1p,X2p).txt"
    #   "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/(0, 8)P(V|do(X1,X2)).txt"
    #   "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/(0, 8)P(V|X1,X2).txt"
    #   "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/(0, 1)P(Y|do(X1,X2),X1p,X2p).txt"
    file_name = "SAVED_EXPERIMENTS/nonid_mnist_images/labels/(1, 4)P(V|X1,X2).txt"

    # JSON is UTF-8; be explicit so decoding doesn't depend on the platform locale.
    with open(file_name, encoding="utf-8") as f:
        cond_result = json.load(f)

    obs_combs = cond_result["obs_comb"][0:MAX_COMBS]
    probs = cond_result["prob"]
    losses = cond_result["loss"]
    if not obs_combs:
        raise ValueError(f"No 'obs_comb' entries to plot in {file_name!r}")

    # First half of the combinations goes to row 0, second half to row 1.
    row_images = {0: [], 1: []}
    for idx, comb in enumerate(obs_combs):
        row_id = floor(2 * idx / len(obs_combs))

        x1, x2, _w, y1, y2, col, thick = comb
        print("keys", x1, x2, "img:", y1, y2, col, thick, " prob:", probs[idx], " loss:", losses[idx])

        # Both images share colour and thickness; only the digit label differs.
        # Generation order (y1 first, then y2) matches the original script.
        generated_image = torch.cat(
            [_generate_digit_image(y1, col, thick), _generate_digit_image(y2, col, thick)],
            dim=0,
        )

        # Each (C, H, W) tensor becomes an (H, W, C) array; the two sit side by side.
        pair = np.concatenate(
            [img.permute(1, 2, 0).detach().cpu().numpy() for img in generated_image],
            axis=1,
        )
        row_images[row_id].append(pair)

        # White gutter sized from the pair itself (was hard-coded to 32x3x3,
        # which broke concatenation for any other image height/channel count).
        row_images[row_id].append(np.full((pair.shape[0], SEPARATOR_WIDTH, pair.shape[2]), 255.0))

    # With a single combination row 1 stays empty, and np.concatenate([]) raises;
    # only join rows that actually received images.
    rows = [np.concatenate(images, axis=1) for images in row_images.values() if images]

    # axis=1 places the two halves side by side in one strip. To stack them
    # vertically instead, concatenate along axis=0 with a horizontal white
    # separator (requires both rows to have equal width).
    final_image = np.concatenate(rows, axis=1)
    plot_trained_digits(1, 1, [final_image], [])










